{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tuning GPT-2 with LoRA and FHE using `LoraTrainer`\n",
    "\n",
    "This notebook demonstrates how to fine-tune a Llama-3.2-1B model using LoRA (Low-Rank Adaptation) with Fully Homomorphic Encryption (FHE). We leverage the `LoraTrainer` API from the `concrete.ml.torch.lora` library to simplify the process.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForLanguageModeling,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "from utils_lora import generate_and_print\n",
    "\n",
    "# Import LoraTrainer from the provided library\n",
    "from concrete.ml.torch.lora import LoraTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concrete ML LoRA fine-tuning is implemented in a 'hybrid' setting: the client machine outsources all\n",
    "computations that involve the original model weights, but runs gradient descent on LoRA layers locally. \n",
    "\n",
    "The client machine thus executes some layers of the LoRA training protocol and it can use CPU or dedicated\n",
    "accelerators for this process. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original model linear layers execute with FHE on:  cpu\n",
      "Non-FHE layers and the LoRA weight optimizer executed on:  cpu\n"
     ]
    }
   ],
   "source": [
    "# Set seed for reproducibility\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "device = \"cpu\"\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(SEED)\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "\n",
    "import concrete_ml_extensions as fhext\n",
    "\n",
    "cuda_fhext = fhext.is_cuda_enabled() and fhext.is_cuda_available()  # pylint: disable=no-member\n",
    "print(\n",
    "    \"Original model linear layers execute with FHE on: \",\n",
    "    \"cuda\" if cuda_fhext else \"cpu\",\n",
    ")\n",
    "print(\"Non-FHE layers and the LoRA weight optimizer executed on: \", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up\n",
    "\n",
    "Load the LLAMA model, tokenize the dataset, and create LoRA fine-tuning configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b27cab98a5264b8f9e88b2c49d58c6ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40da29f16b8a4ac8b7918ab33f67e634",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6538307fb3f14b8f9c06a63a23364c7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56705155c40b49259f6135ec41874dae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b132d382240f4e10a57b5f24a10e7395",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11bb1569c9644b92be27174865b2c189",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the model and tokenizer\n",
    "model_name = \"meta-llama/Llama-3.2-1B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "\n",
    "# Ensure the tokenizer has a pad token\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# Freeze the original model's weights\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply LoRA configuration\n",
    "peft_config = LoraConfig(\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.01,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=\"all-linear\",\n",
    ")\n",
    "peft_model = get_peft_model(model, peft_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "75e4329d15e1410d84c0bd6face805e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d532c3b0382748dbbb01ff72fb40c016",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/46 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the dataset and tokenize it\n",
    "dataset = load_dataset(\"json\", data_files=\"data_finetune/dataset.jsonl\", split=\"train\")\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"longest\", truncation=True)\n",
    "\n",
    "\n",
    "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define training arguments\n",
    "EPOCHS = 10\n",
    "PER_DEVICE_TRAIN_BATCH_SIZE = 4\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoints\",\n",
    "    num_train_epochs=EPOCHS,\n",
    "    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,\n",
    "    gradient_accumulation_steps=1,\n",
    "    save_total_limit=1,\n",
    "    use_cpu=True,\n",
    "    learning_rate=2e-4,\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    seed=SEED,\n",
    "    data_seed=SEED,\n",
    "    warmup_steps=10,\n",
    "    weight_decay=0.01,\n",
    "    prediction_loss_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create optimizer and scheduler using HuggingFace's Trainer\n",
    "hf_trainer = Trainer(\n",
    "    model=peft_model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "train_dataloader = hf_trainer.get_train_dataloader()\n",
    "hf_trainer.create_optimizer_and_scheduler(num_training_steps=len(train_dataloader) * EPOCHS)\n",
    "\n",
    "optimizer = hf_trainer.optimizer\n",
    "lr_scheduler = hf_trainer.lr_scheduler\n",
    "\n",
    "\n",
    "# Define a causal LM loss function\n",
    "def causal_lm_loss(logits, labels, ignore_index=-100):\n",
    "    shift_logits = logits[..., :-1, :].contiguous()\n",
    "    shift_labels = labels[..., 1:].contiguous()\n",
    "    shift_logits = shift_logits.view(-1, shift_logits.size(-1))\n",
    "    shift_labels = shift_labels.view(-1)\n",
    "    loss = torch.nn.functional.cross_entropy(\n",
    "        shift_logits, shift_labels, ignore_index=ignore_index, reduction=\"mean\"\n",
    "    )\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the original model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial generation with base model:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from concrete.ml.sklearn import LogisticRegression\n",
      "\n",
      "model = LogisticRegression( eta=0.1, n_iter=1000, random_state=42)\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Print the initial generation with the base model\n",
    "PROMPT = \"from concrete.ml.sklearn import LogisticRegression\\n\\nmodel = LogisticRegression(\"\n",
    "print(\"Initial generation with base model:\")\n",
    "print(generate_and_print(PROMPT, model, tokenizer, seed=SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the model to use FHE\n",
    "\n",
    "Similarily to all Concrete ML models, LoRA fine-tuning is set up using by compiling the\n",
    "model. For this, a representative set of data is required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LoRA layers detected in the model.\n"
     ]
    }
   ],
   "source": [
    "# Prepare input data for calibration\n",
    "lengths = [len(item[\"input_ids\"]) for item in tokenized_dataset]\n",
    "if not all(length == lengths[0] for length in lengths):\n",
    "    raise ValueError(\"All examples must have the same length for calibration.\")\n",
    "BLOCK_SIZE = lengths[0]\n",
    "\n",
    "input_tensor = torch.randint(\n",
    "    0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long\n",
    ")\n",
    "label_tensor = torch.randint(\n",
    "    0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long\n",
    ")\n",
    "attention_mask = torch.ones((PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long)\n",
    "inputset = {\"input_ids\": input_tensor, \"attention_mask\": attention_mask, \"labels\": label_tensor}\n",
    "\n",
    "# Initialize LoraTrainer\n",
    "training_args_dict = vars(training_args)\n",
    "lora_trainer = LoraTrainer(\n",
    "    model=peft_model,\n",
    "    optimizer=optimizer,\n",
    "    loss_fn=causal_lm_loss,\n",
    "    lr_scheduler=lr_scheduler,\n",
    "    training_args=training_args_dict,\n",
    "    n_layers_to_skip_for_backprop=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the model using quantization. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10e2266fe2ab402b8fa80cf7f56f7fcf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Compiling FHE layers:   0%|          | 0/221 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compile the model with FHE\n",
    "lora_trainer.compile(inputset, n_bits=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test-run Concrete ML LoRA fine-tuning on clear data with quantization\n",
    "\n",
    "To check that everything works properly, it's possible to dry-run the fine-tuning on clear data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training using LoraTrainer...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bad7db6974f2410fba09237867744688",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/10 [00:00<?, ?epoch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training completed. Final Avg Loss: 0.0885, FHE Mode: disable\n"
     ]
    }
   ],
   "source": [
    "# Train the model using LoraTrainer\n",
    "print(\"Starting training using LoraTrainer...\")\n",
    "lora_trainer.train(train_dataloader, num_epochs=EPOCHS, fhe=\"disable\", device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "We show code generation using the original model versus the fine-tuned model. This is done\n",
    "by disabling the lora layers in the HybridFHEModel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original model generation:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from concrete.ml.sklearn import LogisticRegression\n",
      "\n",
      "model = LogisticRegression( eta=0.1, max_iter=1000, random_state=1)\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Compare generation before and after fine-tuning\n",
    "peft_model.disable_adapter_layers()\n",
    "print(\"Original model generation:\")\n",
    "print(generate_and_print(PROMPT, peft_model, tokenizer, seed=SEED))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fine-tuned model generation:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from concrete.ml.sklearn import LogisticRegression\n",
      "\n",
      "model = LogisticRegression( eta=0.01, n_bits=8)\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "peft_model.enable_adapter_layers()\n",
    "print(\"Fine-tuned model generation:\")\n",
    "print(generate_and_print(PROMPT, peft_model, tokenizer, seed=SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning on encrypted data\n",
    "\n",
    "Next, we benchmark the time to train on a single encrypted example, a \n",
    "code snippet of ~130 tokens. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a1caeea5bc048ecadebb249e20a55c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/46 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "FHE_SEQUENCE_LENGTH = 16\n",
    "\n",
    "\n",
    "def tokenize_function_fhe(examples):\n",
    "    return tokenizer(\n",
    "        examples[\"text\"], padding=\"max_length\", truncation=True, max_length=FHE_SEQUENCE_LENGTH\n",
    "    )\n",
    "\n",
    "\n",
    "tokenized_dataset = dataset.map(tokenize_function_fhe, batched=True)\n",
    "\n",
    "# Create a small data loader with a single example\n",
    "hf_trainer = Trainer(\n",
    "    model=peft_model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset.select(list(range(PER_DEVICE_TRAIN_BATCH_SIZE))),\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "train_dataloader = hf_trainer.get_train_dataloader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea0d8b9b367a4e08b6cfc96d11745b1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/1 [00:00<?, ?epoch/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training completed. Final Avg Loss: 13.6442, FHE Mode: execute\n",
      "Trained on one encrypted batch of 4 examples of 16 tokens in 5321.782980680466 seconds on cpu\n"
     ]
    }
   ],
   "source": [
    "# Execute fine-tuning, using the GPU when it is available\n",
    "fhe_epochs = 1\n",
    "import time\n",
    "\n",
    "start = time.time()\n",
    "lora_trainer.train(train_dataloader, num_epochs=fhe_epochs, fhe=\"execute\")\n",
    "duration = time.time() - start\n",
    "print(\n",
    "    (\n",
    "        f\"Trained on one encrypted batch of {PER_DEVICE_TRAIN_BATCH_SIZE} \"\n",
    "        f\"examples of {FHE_SEQUENCE_LENGTH} tokens in {duration} seconds on {device}\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the fine-tuned LoRA weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to: deployment/llama_lora_finetuned\n"
     ]
    }
   ],
   "source": [
    "# Save the fine-tuned model\n",
    "save_path = Path(\"deployment/llama_lora_finetuned\")\n",
    "if save_path.is_dir() and any(save_path.iterdir()):\n",
    "    shutil.rmtree(save_path)\n",
    "lora_trainer.save_and_clear_private_info(save_path)\n",
    "\n",
    "print(\"Model saved to:\", save_path)"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
